Amazon SageMaker Debuggerを試してみた #reinvent
最初に
こんにちはデータアナリティクス事業本部のyoshimです。
re:Invent2019にて「Amazon SageMaker Debugger」のリリースが発表されたので、早速試してみました。
「Amazon SageMaker Debugger」はその名前の通り「自分で書いたモデルのデバッグ」の際に利用します。
自分でモデルを書いていると、色々と意図しないバグが残ってしまいがちですが、開発の最初から「Amazon SageMaker Debugger」を使っていると、バグを減らして、開発時間も短縮できるかもしれません。
(一部記述内容に誤りがあったため、2019年12月6日に修正しました)
目次
1.Amazon SageMaker Debuggerとは
最初に、「Amazon SageMaker Debugger」とはなんのか、についてですがざっくり言うと「トレーニング中のロスや精度等のスカラー値をモニタリングすることで、モデルの開発や評価をサポートするサービス」と言えそうです。
具体的には、「指定したルール」で「トレーニング中のロスや精度等のスカラー値をモニタリング」し、「その結果からアラート(CloudWatch Event)をあげる」、「後から結果の詳細を確認する」といったことが可能です。
また、一般的なフレームワーク(TensorFlow, PyTorch, Apache MXNet, XGBoost)がサポートされており、既存のスクリプトの修正は不要です。
(「TensorFlow Horovod」を使って分散学習をさせたものについては、2019年12月5日時点ではサポートされていないようです。また、python2系はサポート対象外です)
より詳細についてはドキュメントをご参照いただきたいのですが、ざっくり機能の概要としては以上となります。
2.やってみた
ここからは実際に手を動かして理解を深めようと思います。
こちらの内容を実際に手を動かして確認してみました。
以下、スクリプトの内容を処理の段階ごとに区切って確認していきます。
2-1.用意されているスクリプト、デバッグのルールを使ってモデルの学習をする
サンプルとして用意されていたスクリプトやデータを使ってモデルの学習をします。
また、デバッグのルールについては「SageMaker側で事前に用意されているもの(15種類)」と「自前で定義するカスタムルール」の2パターンを利用できるのですが、今回はSageMaker側で事前に用意されているルールを利用しました。
デバッグをするために「entry_point」に指定するスクリプトを修正する必要はなく、下記の通りルールを指定してあげる数行を追加すればOKです。
from sagemaker.debugger import Rule, rule_configs from sagemaker.tensorflow import TensorFlow import sagemaker # define the entrypoint script entrypoint_script='src/mnist_zerocodechange.py' # hyper parameter # sampleのコードよりも増やしてます hyperparameters = { "num_epochs": 12 } # デバッグに適用するルールをList形式で指定 # SageMaker側で事前に用意されているルール、自前で用意するルール、の2パターンを利用可能 # 今回はSageMaker側で事前に用意されているルールを利用する rules = [ Rule.sagemaker(rule_configs.vanishing_gradient()), Rule.sagemaker(rule_configs.loss_not_decreasing()) ] ## estimatorに「rules」変数として指定する estimator = TensorFlow( role=sagemaker.get_execution_role(), base_job_name='smdebugger-demo-mnist-tensorflow', train_instance_count=1, train_instance_type='ml.m5.12xlarge', train_volume_size=400, entry_point=entrypoint_script, framework_version='1.15', py_version='py3', train_max_run=3600, script_mode=True, hyperparameters=hyperparameters, ## New parameter rules = rules ) estimator.fit(wait=True)
さて、これで「勾配消失」、「ロスが減らない」といった事象をチェックするルールを指定した状態で学習を実行できました。
(自前で定義したルールを利用する場合は、PythonのスクリプトをS3に格納してパスを指定する必要があります)
以上の形で学習を開始したところ、早速下記のようなログが出力されました。
これから検証する段階なのですが、「InProgress」という表記でした。
学習の終了直後、下記のようなログが出力されました。
ロスを確認したところ、確かにロスが下がらない場合がありました。
また、勾配消失はまだ発生していないので「InProgress」のままです。
2-2.ルールの検証結果を確認
学習が終わったので、今回指定したルールの検証結果を確認しましょう。
勾配消失は発生していないので「NoIssuesFound」、ロスが下がらない現象は発生していたので「IssuesFound」となっています。
estimator.latest_training_job.rule_job_summary()
[{'RuleConfigurationName': 'VanishingGradient', 'RuleEvaluationJobArn': 'arn:aws:sagemaker:us-east-1:123456789012:processing-job/smdebugger-demo-mnist-tens-vanishinggradient-f14fd901', 'RuleEvaluationStatus': 'NoIssuesFound', 'LastModifiedTime': datetime.datetime(2019, 12, 5, 22, 49, 12, 899000, tzinfo=tzlocal())}, {'RuleConfigurationName': 'LossNotDecreasing', 'RuleEvaluationJobArn': 'arn:aws:sagemaker:us-east-1:123456789012:processing-job/smdebugger-demo-mnist-tens-lossnotdecreasing-33cbb36f', 'RuleEvaluationStatus': 'IssuesFound', 'StatusDetails': 'RuleEvaluationConditionMet: Evaluation of the rule LossNotDecreasing at step 4500 resulted in the condition being met\n', 'LastModifiedTime': datetime.datetime(2019, 12, 5, 22, 49, 12, 899000, tzinfo=tzlocal())}]
また、ログがCloudWatchLogsに出力されているので、その参照先URLを出力して、確認できます。
何か問題があった場合はこのリンク先から詳細を確認できます。
def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn): """Helper function to get the rule job name with correct casing""" return "{}-{}-{}".format( training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:] ) def _get_cw_url_for_rule_job(rule_job_name, region): return "https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix".format(region, region, rule_job_name) def get_rule_jobs_cw_urls(estimator): region = boto3.Session().region_name training_job = estimator.latest_training_job training_job_name = training_job.describe()["TrainingJobName"] rule_eval_statuses = training_job.describe()["DebugRuleEvaluationStatuses"] result={} for status in rule_eval_statuses: if status.get("RuleEvaluationJobArn", None) is not None: rule_job_name = _get_rule_job_name(training_job_name, status["RuleConfigurationName"], status["RuleEvaluationJobArn"]) result[status["RuleConfigurationName"]] = _get_cw_url_for_rule_job(rule_job_name, region) return result get_rule_jobs_cw_urls(estimator)
{'VanishingGradient': 'https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/aws/sagemaker/ProcessingJobs;prefix=smdebugger-demo-mnist-tens-VanishingGradient-f14fd901;streamFilter=typeLogStreamPrefix', 'LossNotDecreasing': 'https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/aws/sagemaker/ProcessingJobs;prefix=smdebugger-demo-mnist-tens-LossNotDecreasing-33cbb36f;streamFilter=typeLogStreamPrefix'}
2-3.探索的に確認する
また、より探索的に結果を確認することもできます。
from smdebug.trials import create_trial trial = create_trial(estimator.latest_job_debugger_artifacts_path()) # 保持している全てのテンソルをリストとして表示する # どれを可視化するか、というのを選ぶ際にまずこれで確認 print(trial.tensor_names())
%matplotlib inline import matplotlib.pyplot as plt import re # 今回はロスを可視化する trial.tensor_names(collection="losses") # Define a function that, for the given tensor name, walks through all # the iterations for which we have data and fetches the value. # Returns the set of steps and the values def get_data(trial, tname): tensor = trial.tensor(tname) steps = tensor.steps() vals = [tensor.value(s) for s in steps] return steps, vals def plot_tensors(trial, collection_name, ylabel=''): """ Takes a `trial` and plots all tensors that match the given regex. """ plt.figure( num=1, figsize=(8, 8), dpi=80, facecolor='w', edgecolor='k') tensors = trial.tensor_names(collection=collection_name) for tensor_name in sorted(tensors): steps, data = get_data(trial, tensor_name) plt.plot(steps, data, label=tensor_name) plt.legend(bbox_to_anchor=(1.04,1), loc='upper left') plt.xlabel('Iteration') plt.ylabel(ylabel) plt.show() plot_tensors(trial, "losses", ylabel="Loss")
ロスが可視化できました。
確かに、ロスが下がっていないところがあるのが確認できます。
3.まとめ
自分でスクリプトを組んでいると、どうしても「今までの経験からの勘」とかで対応してしまいがちですが、こういった部分をツールがサポートしてくれるのは嬉しいですね。
スクリプトの変更も不要で、「Amazon SageMaker Debugger」が自前で用意しているルールを使うだけなら本当に簡単に適用できるので、開発の初期段階から積極的に使っていきたいですね。
また、今回はノートブックインスタンス上からノートブックファイルを実行したのですが、「Amazon SageMaker Studio」上から実行した時の見え方も今後確認してみたいと思います。